import os
import re
import json
import random
import torch
from rich import print
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, pipeline
from policies import BasePolicy
from policies.utils.replay_buffer import ReplayBuffer
from policies.utils.reflection import Reflection, get_obs_message
import policies.prompts as prompts

# Hugging Face model identifiers associated with this policy module.
# NOTE(review): this list is not referenced anywhere in the visible code
# (CommLlamaPolicy defaults to the 8B Instruct checkpoint) — confirm how it
# is consumed by callers.
MODELS = [
    "meta-llama/Meta-Llama-3-70B-Instruct",
]
"""
Three concepts of the policy:
(1) message_history
(2) prompt
(3) replay_buffer
"""
class CommLlamaPolicy(BasePolicy):
    def __init__(self,
                 model="meta-llama/Meta-Llama-3-8B-Instruct",
                 agent_id="",
                 temperature=0.2,
                 adapter=None,
                 device="cuda",
                 comm_only=False,
                 control_only=False,
                 skip_frames=0,
                 batch_size=1,
                 logdir=None,
                 is_focal=False
    ):
        """
        Build a language-model driven policy.

        Args:
            model: name or path handed to ``from_pretrained`` for both the
                causal LM and its tokenizer.
            agent_id: identifier stored on the policy.
            temperature: generation temperature, also forwarded to the
                reflection module.
            adapter: optional adapter passed to ``model.load_adapter``.
            device: requested torch device; "cpu" is used whenever CUDA is
                not available.
            comm_only: marks the agent as communication only.
            control_only: marks the agent as control only.
            skip_frames: number of initial steps for which ``act`` returns
                the default action without querying the model.
            batch_size: number of transitions sampled in ``learn``.
            logdir: directory used by ``save`` / ``load`` for checkpoints.
            is_focal: flag forwarded to the reflection module.
        """
        # --- agent identity and role flags ---
        self.agent_id = agent_id
        self.is_focal = is_focal
        self.comm_only = comm_only
        self.control_only = control_only
        self.skip_frames = skip_frames
        # Measured in frames: decision_frequency / frame_rate = seconds per decision.
        self.decision_frequency = 10

        # --- language model ---
        self.temperature = temperature
        self.model = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.bfloat16)
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        if adapter is not None:
            self.model.load_adapter(adapter)
        # Generation stops on either the tokenizer's EOS or the end-of-turn token.
        end_of_turn_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        self.terminators = [self.tokenizer.eos_token_id, end_of_turn_id]
        target_device = device if torch.cuda.is_available() else "cpu"
        self.device = torch.device(target_device)
        self.model.to(self.device)
        self.model.eval()

        # --- static prompt pieces ---
        self.instruction = prompts.get_instruction(comm_only)
        self.common_sense = prompts.get_common_sense(comm_only, control_only)
        self.message_history = []

        # --- learning module ---
        self.batch_size = batch_size
        self.iteration = 0
        self.experience = None
        self.replay_buffer = ReplayBuffer()
        reflection_settings = dict(
            model=self.model,
            tokenizer=self.tokenizer,
            device=self.device,
            temperature=self.temperature,
            comm_only=self.comm_only,
            control_only=self.control_only,
            is_focal=self.is_focal,
        )
        self.reflection = Reflection(**reflection_settings)

        # --- per-episode bookkeeping ---
        self.current_observation = None
        self.step_count = 0
        self.episode_return = 0
        self.prev_action = None
        self.plan = None

        # --- persisted knowledge ---
        self.logdir = logdir
        self.learned_knowledge = ""
        self.cooperative_knowledge = ""

    def reset(self):
        """
        Reset the agent when a new episode starts
        Warning: you should not reset any learning modules here!
        """
        self.step_count = 0
        self.episode_return = 0
        self.current_observation = None
        self.prev_action = None
        self.plan = None
        self.message_history = []
        print("Resetting the agent")

    def observe(self, obs, reward, terminated, truncated, info):
        if obs is None:
            self.current_observation = None
            return
        self.current_observation = obs
        self.episode_return += reward

    def act(self):
        """
        Generate actions using language model
        """
        self.step_count += 1
        if self.step_count <= self.skip_frames: 
            return {"command":"go", "message":""}
        if self.step_count % self.decision_frequency == 1:
            response = self.prompting()
            action = self.parse_action(response)
            self.prev_action = action
            print(self.message_history)
        action = self.prev_action
        return action

    def chat(self, prompt, json_format=False):
        prompt = self.tokenizer.apply_chat_template(prompt,
                                                    add_generation_prompt=True,
                                                    return_tensors='pt').to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(prompt,
                                          max_length=8192,
                                          temperature=self.temperature,
                                          eos_token_id=self.terminators
                                          )
        response = self.tokenizer.decode(outputs[0][prompt.shape[-1]:], skip_special_tokens=True)
        return response

    def prompting(self)->dict:
        """
        Generate the chain of thought prompting

        The chat history is rebuilt from scratch on every call: the two
        system prompts, any stored cooperative / learned knowledge, the
        current observation, then a reasoning round followed by a decision
        round. Every prompt and every model reply is appended to
        ``self.message_history``.

        Returns:
            dict: ``{"reasoning": <reply to the first prompt>,
            "action": <reply to the second prompt>}``, both raw model text.
        """
        response = {}

        # System message step
        self.message_history = [
            {"role":"system", "content":self.instruction},
            {"role":"system", "content":self.common_sense},
        ]

        # Previous knowledge step: stored knowledge is replayed as earlier
        # assistant turns, cooperative knowledge first.
        for knowledge in (self.cooperative_knowledge, self.learned_knowledge):
            if knowledge:
                self.message_history.append({"role":"assistant", "content":knowledge})

        # Prompt Observation
        observation = self.current_observation
        if observation is None:
            obs_message = "No observation"
        else:
            obs_message = get_obs_message(observation)
        self.message_history.append({"role":"user", "content": obs_message})

        # Reasoning step
        cot_prompt1 = prompts.get_cot_prompt_1(self.comm_only, self.control_only)
        self.message_history.append({"role":"user", "content":cot_prompt1})
        response['reasoning'] = self.chat(self.message_history)
        self.message_history.append({"role":"assistant", "content":response['reasoning']})

        # Decision step
        cot_prompt2 = prompts.get_cot_prompt_2(self.comm_only, self.control_only)
        self.message_history.append({"role":"user", "content":cot_prompt2})
        response['action'] = self.chat(self.message_history, json_format=True)

        # Record the action generated by model
        self.message_history.append({"role":"assistant", "content":response['action']})
        return response

    def parse_action(self, response):
        """
        Parse the action into a dictionary
        """
        action = {"command":"go", "reasoning":response['reasoning']}
        try:
            response = re.findall(r"\{[^*]*\}", response['action'])[0]
            if response:
                response = json.loads(response)
                if "command" in response:
                    action["command"] = response["command"]
                if "message" in response:
                    action["message"] = response["message"]
        except:
            print("Error in parsing the response", response)
        action = dict(sorted(action.items(), key=lambda item: item[0]))
        return action

    def get_episode_return(self)->float:
        return self.episode_return

    def store_transition(self, transition)->None:
        """
        Store the transition in replay buffer for learning
        """
        if transition.obs is None:
            return
        self.replay_buffer.add(transition)

    def learn(self)->None:
        """
        Learn from the experince
        """
        self.iteration += 1
        batch = self.replay_buffer.sample_batch(batch_size=self.batch_size)
        updated_knowledge = self.reflection.reflect(batch)
        self.learned_knowledge = updated_knowledge
        self.save(self.iteration)
        # Clear replay buffer
        self.replay_buffer.clear()

    def debrief(self, collective_knowledges, learned_knowledges, agent_in_debrief, batch_size, last_round=False):
        """
        Debrief the agent after the training
        """
        # debrief batch size could be smaller than default batch size
        batch = self.replay_buffer.sample_batch(batch_size=batch_size)
        # ask agent to reflect on the batch and propose cooperation and ego knowledge
        collective_knowledge, learned_knowledge = self.reflection.debrief(batch,
                                                                          collective_knowledges,
                                                                          learned_knowledges,
                                                                          agent_in_debrief,
                                                                          last_round=last_round)
        return collective_knowledge, learned_knowledge

    def internalize(self):
        """
        Internalize the knowledge from the debriefing
        """
        if self.reflection.cooperative_knowledge:
            self.cooperative_knowledge = self.reflection.cooperative_knowledge
        if self.reflection.learned_knowledge:
            self.learned_knowledge = self.reflection.learned_knowledge

    def save(self, ckpt_num):
        """
        Save the knowledge for the future training
        """
        assert self.logdir is not None
        if not os.path.exists(self.logdir):
            os.makedirs(self.logdir, exist_ok=True)
        json.dump(
            {"knowledge": self.learned_knowledge,
             "cooperative_knowledge": self.cooperative_knowledge},
            open(os.path.join(self.logdir, f"ckpt-{ckpt_num}.json"), "w")
        )
    
    def load(self, ckpt_num):
        """
        load the knowledge from the previous training
        """
        if self.logdir is None:
            return
        if ckpt_num == -1:
            # load the latest checkpoint
            ckpt_num = max([int(ckpt.split("-")[-1].split(".")[0]) for ckpt in os.listdir(self.logdir)])
        knowledge = json.load(open(os.path.join(
                                                self.logdir,
                                                f"ckpt-{ckpt_num}.json"
                                        ), "r"))
        print(f"Loading checkpoint {ckpt_num}")
        if "knowledge" in knowledge:
            self.learned_knowledge = knowledge["knowledge"]
        if "cooperative_knowledge" in knowledge:
            self.cooperative_knowledge = knowledge["cooperative_knowledge"]
